#!/usr/bin/env python
# -*- coding: utf-8 -*-

""" By Martin Senande-Rivera
	For Spatial and temporal expansion of global wildland fire activity in response to climate change """

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from mpl_toolkits.basemap import Basemap, maskoceans
from matplotlib.cm import get_cmap
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap

# ---- Input data -------------------------------------------------------------
# Directory layout of the upstream pipeline steps (relative paths).
FC_folder = '../2-Classification/'
BA_folder = '../../DATA/GFED4/'
BA_mean_folder = '../0-Variables_plots/'  # NOTE(review): not referenced elsewhere in this script

# Fire-prone classification (FC).
# Monthly flags are binarised (any positive value -> 1); yearly flags are the
# per-year maximum of the monthly flags, binarised the same way.
FC_months = xr.open_dataset(FC_folder + 'FC_array.nc').FC
FC_months = xr.where(FC_months > 0, 1, FC_months)
_fc_yearly_max = FC_months.groupby('time.year').max(dim='time')
FC_years = xr.where(_fc_yearly_max > 0, 1, _fc_yearly_max)
FC_total = xr.open_dataset(FC_folder + 'FC.nc')
FC = FC_total.FC  # FC class code of every grid point

# GFED4 burned area: monthly, yearly and mean ('a_mean' file) fields.
BA_months = xr.open_dataset(BA_folder + 'GFED4_BA_months.nc').BurnedArea
BA_years = xr.open_dataset(BA_folder + 'GFED4_BA_years.nc').BurnedArea
BA_mean = xr.open_dataset(BA_folder + 'GFED4_BA_a_mean.nc').BurnedArea

# Koppen-Geiger (KG) climate classification.
ruta_kg = '../1-Threshold_selection/'  # KG classification data path
ds_kg = xr.open_dataset(ruta_kg + 'KG_class.nc')
KG = ds_kg.KG

##### Classified BA and points #####
####################################

# Share (%) of mean burned area falling in each FC class, globally and per KG group.
# FC code 10*k + c is the class-c code paired with KG group k
# (c = 1 recurrent, 2 occasional, 3 infrequent).
def _ba_class_shares(c):
    """Return [global, KG1, KG2, KG3, KG4] burned-area shares (%) for class digit ``c``.

    The global share is relative to the total of BA_mean; each per-group share
    is relative to the BA_mean total inside that KG group.
    """
    codes = [10. * k + c for k in (1, 2, 3, 4)]
    in_class = (FC == codes[0]) | (FC == codes[1]) | (FC == codes[2]) | (FC == codes[3])
    shares = [BA_mean.where(in_class).sum() / BA_mean.sum() * 100.]
    for k, code in zip((1, 2, 3, 4), codes):
        shares.append(BA_mean.where((KG == k) & (FC == code)).sum()
                      / BA_mean.where(KG == k).sum() * 100.)
    return shares


BAper_r, BAper_KG1_r, BAper_KG2_r, BAper_KG3_r, BAper_KG4_r = _ba_class_shares(1)
BAper_o, BAper_KG1_o, BAper_KG2_o, BAper_KG3_o, BAper_KG4_o = _ba_class_shares(2)
BAper_i, BAper_KG1_i, BAper_KG2_i, BAper_KG3_i, BAper_KG4_i = _ba_class_shares(3)

Per_r = np.array([BAper_r, BAper_KG1_r, BAper_KG2_r, BAper_KG3_r, BAper_KG4_r])
Per_o = np.array([BAper_o, BAper_KG1_o, BAper_KG2_o, BAper_KG3_o, BAper_KG4_o])
Per_i = np.array([BAper_i, BAper_KG1_i, BAper_KG2_i, BAper_KG3_i, BAper_KG4_i])

# Area weight of every grid point: cos(latitude), which is approximately
# proportional to the cell area on a regular lat/lon grid. The weight is
# broadcast by dimension name over the rest of FC's grid. This replaces tiling
# a hard-coded 1440 longitudes and assuming a (lat, lon) layout. The weight is
# created as float so it cannot be truncated to 0 if FC has an integer dtype.
Area_points = (np.cos(FC.lat * np.pi / 180) * xr.ones_like(FC, dtype=float)).transpose(*FC.dims)

# Area-weighted share (%) of *burned* points (BA_mean > 0) in each FC class,
# globally and per KG group.
_burned_pts = BA_mean > 0
_kg_1to4 = (KG == 1.) | (KG == 2.) | (KG == 3.) | (KG == 4.)


def _area_class_shares(c):
    """Return [global, KG1, KG2, KG3, KG4] area-weighted shares (%) for class digit ``c``.

    FC code ``10*k + c`` is the class-``c`` code paired with KG group ``k``
    (c = 1 recurrent, 2 occasional, 3 infrequent). Every share counts burned
    points only. The global denominator is restricted to KG groups 1-4; each
    per-group denominator is restricted to that group.
    """
    codes = [10. * k + c for k in (1, 2, 3, 4)]
    in_class = (FC == codes[0]) | (FC == codes[1]) | (FC == codes[2]) | (FC == codes[3])
    shares = [Area_points.where(_burned_pts & in_class).sum()
              / Area_points.where(_burned_pts & _kg_1to4).sum() * 100.]
    for k, code in zip((1, 2, 3, 4), codes):
        burned_in_kg = _burned_pts & (KG == k)
        shares.append(Area_points.where(burned_in_kg & (FC == code)).sum()
                      / Area_points.where(burned_in_kg).sum() * 100.)
    return shares


Aper_r, Aper_KG1_r, Aper_KG2_r, Aper_KG3_r, Aper_KG4_r = _area_class_shares(1)
Aper_o, Aper_KG1_o, Aper_KG2_o, Aper_KG3_o, Aper_KG4_o = _area_class_shares(2)
Aper_i, Aper_KG1_i, Aper_KG2_i, Aper_KG3_i, Aper_KG4_i = _area_class_shares(3)

Per2_r = np.array([Aper_r, Aper_KG1_r, Aper_KG2_r, Aper_KG3_r, Aper_KG4_r])
Per2_o = np.array([Aper_o, Aper_KG1_o, Aper_KG2_o, Aper_KG3_o, Aper_KG4_o])
Per2_i = np.array([Aper_i, Aper_KG1_i, Aper_KG2_i, Aper_KG3_i, Aper_KG4_i])

# ---- Figure: classified share of burned area vs. of burned grid points ----
# Left bar of each group (x1): share of mean BA (Per_*).
# Right bar of each group (x2): area-weighted share of burned points (Per2_*).
# Each bar stacks the recurrent / occasional / infrequent classes, bottom to top.
x = np.array([1,2,3,4,5])                    # groups: Global, Tropical, Arid, Temperate, Boreal
x1 = np.array([0.85,1.85,2.85,3.85,4.85])
x2 = np.array([1.15,2.15,3.15,4.15,5.15])
width = 0.25

# Per-group colours of each stacked layer. The total labels reuse the recurrent
# colour so every label matches its bar (Temperate = '#9f007f').
colors_r = ['k', '#145305', '#cb6c00', '#9f007f', '#063555']
colors_o = ['grey', '#218607', '#ffc021', '#ff6bcb', '#0e84d2']
colors_i = ['lightgrey', '#8ade75', '#ffdb87', '#f3c4ff', '#81c0e9']


def _draw_stack(axis, xpos, layers, bar_width):
    """Draw ``layers`` (recurrent, occasional, infrequent) stacked at ``xpos``.

    Each patch gets face and edge colour via ``set_color`` from its layer's
    colour list. Returns the three BarContainers, bottom layer first.
    """
    containers = []
    bottom = np.zeros(len(xpos))
    for values, layer_colors in zip(layers, (colors_r, colors_o, colors_i)):
        bars = axis.bar(xpos, values, bar_width, bottom=bottom)
        for patch, colour in zip(bars, layer_colors):
            patch.set_color(colour)
        containers.append(bars)
        bottom = bottom + values
    return containers


def _label_totals(axis, xpos, layers, y_offsets):
    """Write each stack's total as '{:.2f}%' at total + offset, in the group's recurrent colour."""
    totals = layers[0] + layers[1] + layers[2]
    for xi, total, dy, colour in zip(xpos, totals, y_offsets, colors_r):
        axis.text(xi, total + dy, '{:.2f}%'.format(total), horizontalalignment='center', color=colour, fontweight='bold')


fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111)
barlist1, barlist2, barlist3 = _draw_stack(ax, x1, (Per_r, Per_o, Per_i), width)
_label_totals(ax, x1, (Per_r, Per_o, Per_i), (-3., -3., -3., -3., -3.))
barlist21, barlist22, barlist23 = _draw_stack(ax, x2, (Per2_r, Per2_o, Per2_i), width)
# Hand-tuned vertical label offsets, kept from the original layout.
_label_totals(ax, x2, (Per2_r, Per2_o, Per2_i), (-3., 2., -4., -3., -3.))
ax.set_ylim(0,100)
ax.set_xticks(x)
ax.set_xticklabels(('Global', 'Tropical', 'Arid', 'Temperate', 'Boreal'))
# Explicit handles are required: with labels only, matplotlib pairs the labels
# with the first patches drawn (the Global/Tropical/Arid recurrent bars), so the
# swatches would not show the recurrent/occasional/infrequent colours.
ax.legend(handles=[barlist1[0], barlist2[0], barlist3[0]], labels=['recurrent', 'occasional', 'infrequent'], loc='upper center', bbox_to_anchor=(0.5, 1.08), fancybox=False, shadow=False, ncol=3)
ax.set_ylabel('Classified (%)')
plt.grid(axis='y')
ratio=0.8
ax.set_aspect(1.0/ax.get_data_ratio()*ratio)
fig.savefig('./BAmean-Points_percentage.png',dpi=150)
plt.close(fig)  # release the figure once it is written to disk

##### ALL POINTS (Classified and not classified) #####
######################################################

# Point-category array written to FC_val.nc:
#   FC class code (11-43)  classified point with BA > 0ha (code kept from FC)
#   0                      unclassified point with BA = 0ha
#   5                      classified point with BA = 0ha
#   6                      unclassified point with BA > 0ha
# The two overrides are mutually exclusive (FC > 0 vs FC == 0), so nesting them
# gives the same result as applying them one after the other.
FC_val = xr.where((BA_mean > 0) & (FC == 0), 6,
                  xr.where((BA_mean == 0) & (FC > 0), 5, FC.copy()))
FC_val.to_netcdf('FC_val.nc')

# Area-weighted share (%) of points in each FC_val category, over KG groups 1-4
# together and within each KG group. Categories: classified with BA>0ha
# (FC_val > 10), classified with BA=0ha (5), unclassified with BA>0ha (6),
# unclassified with BA=0ha (0).
_kg_regions = [(KG==1.) | (KG==2.) | (KG==3.) | (KG==4.), KG==1., KG==2., KG==3., KG==4.]


def _region_shares(category):
    """Return [KG1-4, KG1, KG2, KG3, KG4] area-weighted shares (%) of ``category`` points."""
    return [Area_points.where(category & region).sum() / Area_points.where(region).sum() * 100.
            for region in _kg_regions]


(Points_burn_class, Points_KG1_burn_class, Points_KG2_burn_class,
 Points_KG3_burn_class, Points_KG4_burn_class) = _region_shares(FC_val > 10.)
(Points_noburn_class, Points_KG1_noburn_class, Points_KG2_noburn_class,
 Points_KG3_noburn_class, Points_KG4_noburn_class) = _region_shares(FC_val == 5.)
(Points_burn_noclass, Points_KG1_burn_noclass, Points_KG2_burn_noclass,
 Points_KG3_burn_noclass, Points_KG4_burn_noclass) = _region_shares(FC_val == 6.)
(Points_noburn_noclass, Points_KG1_noburn_noclass, Points_KG2_noburn_noclass,
 Points_KG3_noburn_noclass, Points_KG4_noburn_noclass) = _region_shares(FC_val == 0.)


Per_burn_class = np.array([Points_burn_class, Points_KG1_burn_class, Points_KG2_burn_class,
                           Points_KG3_burn_class, Points_KG4_burn_class])
Per_noburn_class = np.array([Points_noburn_class, Points_KG1_noburn_class, Points_KG2_noburn_class,
                             Points_KG3_noburn_class, Points_KG4_noburn_class])
Per_burn_noclass = np.array([Points_burn_noclass, Points_KG1_burn_noclass, Points_KG2_burn_noclass,
                             Points_KG3_burn_noclass, Points_KG4_burn_noclass])
Per_noburn_noclass = np.array([Points_noburn_noclass, Points_KG1_noburn_noclass, Points_KG2_noburn_noclass,
                               Points_KG3_noburn_noclass, Points_KG4_noburn_noclass])

# ---- Figure: share of grid points per burned/classified category ----
# Stack order (bottom -> top): BA>0ha & classified, BA=0ha & unclassified,
# BA=0ha & classified, BA>0ha & unclassified.
x = np.array([1,2,3,4,5])   # groups: Global, Tropical, Arid, Temperate, Boreal
width = 0.5
fig = plt.figure(figsize=(8,6))
ax = fig.add_subplot(111)
barlist1 = ax.bar(x, Per_burn_class, width)
# Bottom layer is coloured per group (face and edge via set_color).
for _patch, _colour in zip(barlist1, ['saddlebrown', '#145305', '#cb6c00', '#9f007f', '#063555']):
    _patch.set_color(_colour)
barlist2 = ax.bar(x, Per_noburn_noclass, width, bottom=Per_burn_class, color='whitesmoke')
barlist3 = ax.bar(x, Per_noburn_class, width, bottom=(Per_burn_class+Per_noburn_noclass), color='k')
barlist4 = ax.bar(x, Per_burn_noclass, width, bottom=(Per_burn_class+Per_noburn_noclass+Per_noburn_class), color='grey')
ax.set_ylim(0,100)
ax.set_xticks(x)
ax.set_xticklabels(('Global', 'Tropical', 'Arid', 'Temperate', 'Boreal'))
# Explicit handles are required: with labels only, matplotlib pairs the labels
# with the first four patches drawn (the coloured bottom-layer bars), so the
# swatches would not match the four stacked categories.
ax.legend(handles=[barlist1[0], barlist2[0], barlist3[0], barlist4[0]], labels=['BA>0ha | C', 'BA=0ha | NC', 'BA=0ha | C', 'BA>0ha | NC'], loc='upper center', bbox_to_anchor=(0.5, 1.08), fancybox=False, shadow=False, ncol=4)
ax.set_ylabel('Spatial points (%)')
plt.grid(axis='y')
ratio=0.8
ax.set_aspect(1.0/ax.get_data_ratio()*ratio)
fig.savefig('./Points_percentage_total.png',dpi=150)
plt.close(fig)  # release the figure once it is written to disk

# Summary of how well the classification matches observed burned area.
_well_classified_pct = Points_burn_class.values + Points_noburn_noclass.values
print('\n')
print('Percentage of well classified points: {:.3f}%'.format(_well_classified_pct))
print('Percentage classified points with  BA=0ha: {:.3f}%'.format(Points_noburn_class.values))

# Among classified points with BA=0ha (FC_val == 5): area-weighted share (%) per class.
_noburn_classified = FC_val == 5.
_noburn_classified_area = Area_points.where(_noburn_classified).sum()


def _noburn_class_share(c):
    """Share (%) of the BA=0ha classified area whose FC code is 10*k + c for a KG group k in 1-4."""
    in_class = (FC == 10. + c) | (FC == 20. + c) | (FC == 30. + c) | (FC == 40. + c)
    return Area_points.where(_noburn_classified & in_class).sum() / _noburn_classified_area * 100.


Points_noburn_class_r = _noburn_class_share(1)
Points_noburn_class_o = _noburn_class_share(2)
Points_noburn_class_i = _noburn_class_share(3)
print('Percentage of classified points with  BA=0ha, but classified as recurrent fire-prone: {:.3f}%'.format(Points_noburn_class_r.values))
print('Percentage of classified points with  BA=0ha, but classified as occasional fire-prone: {:.3f}%'.format(Points_noburn_class_o.values))
print('Percentage of classified points with  BA=0ha, but classified as infrequent fire-prone: {:.3f}%'.format(Points_noburn_class_i.values))


### BA percentages Spatial ###
##############################

# Share (%) of mean burned area located on classified points (FC >= 1),
# globally and within each KG group 1-4.
_is_classified = FC >= 1.
BA_mean_percentage = BA_mean.where(_is_classified).sum() / BA_mean.sum() * 100.
(BA_mean_percentage_KG1,
 BA_mean_percentage_KG2,
 BA_mean_percentage_KG3,
 BA_mean_percentage_KG4) = [BA_mean.where((KG == k) & _is_classified).sum() / BA_mean.where(KG == k).sum() * 100.
                            for k in (1, 2, 3, 4)]

print('\n')
print('BAmean classified: {:.3f}%'.format(BA_mean_percentage.values))

for _kg_header, _kg_pct in (('\n KG1', BA_mean_percentage_KG1),
                            ('\n KG2', BA_mean_percentage_KG2),
                            ('\n KG3', BA_mean_percentage_KG3),
                            ('\n KG4', BA_mean_percentage_KG4)):
    print(_kg_header)
    print('BAmean classified: {:.3f}%'.format(_kg_pct.values))

### BA percentages Temporal ###
###############################

# Yearly: BA in years flagged fire-prone (FC_years == 1) relative to all yearly
# BA on classified points (FC >= 1).
# Monthly: BA in months flagged fire-prone (FC_months == 1) relative to BA in
# fire-prone years.
# NOTE(review): assumes BA_years/FC_years and BA_months/FC_months share
# dimension names and coordinates; confirm against the input files.
_ba_in_fp_years = BA_years.where(FC_years == 1.).sum()
BA_years_percentage = _ba_in_fp_years / BA_years.where(FC >= 1.).sum() * 100.
BA_months_percentage = BA_months.where(FC_months == 1.).sum() / _ba_in_fp_years * 100.

print('\n')
for _template, _pct in (('Yearly BA classified: {:.3f}%', BA_years_percentage),
                        ('Monthly BA classified: {:.3f}%', BA_months_percentage)):
    print(_template.format(_pct.values))
